{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import time\n",
    "import datetime\n",
    "import pickle\n",
    "import sqlite3\n",
    "import re\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.pyplot import imshow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the SQLite database holding the images and metadata tables (path on rte)\n",
    "db_path = os.path.expanduser(\"~/data/db/arxiv_db_images.sqlite3\")\n",
    "\n",
    "# sqlite3.connect() creates a new, empty database when the file does not exist, which\n",
    "# would only show up later as a confusing \"no such table\" error, so fail early instead\n",
    "if not os.path.isfile(db_path):\n",
    "    raise FileNotFoundError(\"SQLite database not found: \" + db_path)\n",
    "\n",
    "# Open the database and create the cursor used by the query cells\n",
    "db = sqlite3.connect(db_path)\n",
    "c = db.cursor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "from keras.preprocessing import image\n",
    "from keras.applications.imagenet_utils import decode_predictions, preprocess_input\n",
    "from keras.models import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this seems to help with some GPU memory issues\n",
    "# (a TF1-compat InteractiveSession is created with allow_growth switched on)\n",
    "\n",
    "from tensorflow.compat.v1 import ConfigProto\n",
    "from tensorflow.compat.v1 import InteractiveSession\n",
    "\n",
    "config = ConfigProto()\n",
    "# NOTE(review): allow_growth presumably makes TensorFlow claim GPU memory on demand\n",
    "# rather than all at once - confirm against the TensorFlow documentation\n",
    "config.gpu_options.allow_growth = True\n",
    "# the session object is kept in `session`\n",
    "session = InteractiveSession(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load VGG16 with ImageNet weights. include_top=True is presumably required because the\n",
    "# \"fc2\" layer looked up below belongs to the top (fully-connected) part - TODO confirm\n",
    "model = keras.applications.VGG16(weights='imagenet', include_top=True)\n",
    "print(\"model loaded\")\n",
    "model.summary()\n",
    "\n",
    "# set up the feature extractor: same input as the full model, output taken from layer \"fc2\"\n",
    "\n",
    "feat_extractor = Model(inputs=model.input, outputs=model.get_layer(\"fc2\").output)\n",
    "print(\"feature extractor setup\")\n",
    "feat_extractor.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_image(path):\n",
    "    \"\"\"Load the image at `path` and prepare it as input for the model.\n",
    "\n",
    "    The image is resized to the size read from `model.input_shape[1:3]`.\n",
    "    Returns a tuple `(img, x)`: the loaded image object and the array obtained by\n",
    "    adding a leading batch axis and applying `preprocess_input`.\n",
    "    \"\"\"\n",
    "    input_size = model.input_shape[1:3]\n",
    "    img = image.load_img(path, target_size=input_size)\n",
    "    batch = np.expand_dims(image.img_to_array(img), axis=0)\n",
    "    return img, preprocess_input(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select images from SQLite according to category/date range\n",
    "\n",
    "Run one of the query cells below (each one sets `name` and `rows`), then jump down to the next section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Get all images from cs.AI from 2012 (whole year)\n",
    "\n",
    "# label for this selection; it is written into the output file names later on\n",
    "name = \"csAI_all2012\"\n",
    "\n",
    "# Column order matters downstream: row[5] (images.id) is used to build the image file name\n",
    "# and row[6] (metadata.id) is used to look up the accreditations.\n",
    "# The substr/instr expression keeps only the text of metadata.cat before the first space,\n",
    "# so the comparison is made against the first listed category.\n",
    "# NOTE(review): the BETWEEN bounds are plain dates; if metadata.created also stores a time,\n",
    "# rows from 2012-12-31 may be excluded - confirm against the schema.\n",
    "sql = (\"SELECT metadata.cat, images.path, images.filename, images.identifier, metadata.created, images.id, metadata.id \"\n",
    "    \"FROM images \"\n",
    "    \"LEFT JOIN metadata ON images.identifier = metadata.identifier \"\n",
    "    \"WHERE images.x != '' \"\n",
    "    \"AND metadata.created BETWEEN date('2012-01-01') \"\n",
    "    \"AND date('2012-12-31') \"\n",
    "    \"AND substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1) = ?\")\n",
    "\n",
    "# the category is passed as a bound parameter\n",
    "c.execute(sql, (\"cs.AI\", ))\n",
    "rows = c.fetchall()\n",
    "print(len(rows))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get images from stat.ML from 2012 (whole year)\n",
    "\n",
    "# label for this selection; it is written into the output file names later on\n",
    "name = \"statML_all2012\"\n",
    "\n",
    "# Column order matters downstream: row[5] (images.id) is used to build the image file name\n",
    "# and row[6] (metadata.id) is used to look up the accreditations.\n",
    "# The substr/instr expression keeps only the text of metadata.cat before the first space,\n",
    "# so the comparison is made against the first listed category.\n",
    "# NOTE(review): the BETWEEN bounds are plain dates; if metadata.created also stores a time,\n",
    "# rows from 2012-12-31 may be excluded - confirm against the schema.\n",
    "sql = (\"SELECT metadata.cat, images.path, images.filename, images.identifier, metadata.created, images.id, metadata.id \"\n",
    "    \"FROM images \"\n",
    "    \"LEFT JOIN metadata ON images.identifier = metadata.identifier \"\n",
    "    \"WHERE images.x != '' \"\n",
    "    \"AND metadata.created BETWEEN date('2012-01-01') \"\n",
    "    \"AND date('2012-12-31') \"\n",
    "    \"AND substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1) = ? \")\n",
    "\n",
    "# the category is passed as a bound parameter\n",
    "c.execute(sql, (\"stat.ML\", ))\n",
    "rows = c.fetchall()\n",
    "print(len(rows))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate t-SNE\n",
    "\n",
    "Find all image files, run classifier over images, optionally save features, then run t-SNE over features and save image map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_images_path = \"/mnt/hd2/images/all/\"\n",
    "# folder with all images (~/all/ is a symlink to it)\n",
    "all_images_path = os.path.expanduser(\"~/all/\")\n",
    "\n",
    "# one file per selected row, named after row[5] (images.id) with a .jpg extension\n",
    "paths = [all_images_path + str(row[5]) + \".jpg\" for row in rows]\n",
    "\n",
    "print(paths[:3])\n",
    "print(len(paths))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run the classifier over all images\n",
    "\n",
    "num_x = len(paths)\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "features = []\n",
    "for i, image_path in enumerate(paths):\n",
    "    if i % 500 == 0:\n",
    "#         toc = time.clock()\n",
    "        elap = time.time() - start;\n",
    "        print(\"analyzing image %d / %d. Time: %4.4f seconds.\" % (i, len(paths),elap))\n",
    "        start = time.time()\n",
    "    img, x = load_image(image_path)\n",
    "\n",
    "    feat = feat_extractor.predict(x)[0]\n",
    "    features.append(feat)\n",
    "\n",
    "print('finished extracting features for %d images' % len(paths))\n",
    "\n",
    "# write images, features to a pickle file\n",
    "pickle_file = \"features_\" + name + \"_vgg_x\" + str(num_x) + \".pickle\"\n",
    "print(pickle_file)\n",
    "\n",
    "# WRITE\n",
    "with open(pickle_file, \"wb\") as write_file:\n",
    "    pickle.dump([paths, features], write_file)\n",
    "    write_file.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set parameters for t-SNE (these will be written to the filename)\n",
    "# perp is passed to TSNE as perplexity\n",
    "perp = 50\n",
    "# NOTE(review): bPCA is not referenced by any other cell visible in this notebook;\n",
    "# the PCA step does not check it - confirm whether it was meant to toggle PCA\n",
    "bPCA = True\n",
    "# num_iterations is passed to TSNE as n_iter\n",
    "num_iterations = 2000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run t-SNE\n",
    "\n",
    "Note that this involves random processes, so the output may differ between runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read features and then run t-SNE\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load feature files written by this notebook\n",
    "with open(pickle_file, \"rb\") as read_file:\n",
    "    images, features = pickle.load(read_file)\n",
    "\n",
    "# check that we still have the features and list of images\n",
    "print(\"----- checking images and features -----\")\n",
    "print(\"length of images: \" + str(len(images)))\n",
    "print(\"length of features: \" + str(len(features)))\n",
    "for img, f in list(zip(images, features))[0:5]:\n",
    "    print(\"image: %s, features: %0.2f,%0.2f,%0.2f,%0.2f... \"%(img, f[0], f[1], f[2], f[3]))\n",
    "\n",
    "# the PCA step below keeps 300 components, so at least 300 images are required\n",
    "if len(images) >= 300:\n",
    "    features = np.array(features)\n",
    "    print(\"----- running pca across features -----\")\n",
    "    pca = PCA(n_components=300)\n",
    "    pca.fit(features)\n",
    "\n",
    "    pca_features = pca.transform(features)\n",
    "\n",
    "    X = np.array(pca_features)\n",
    "    # NOTE(review): newer scikit-learn releases rename n_iter to max_iter - confirm\n",
    "    # against the installed version before upgrading\n",
    "    tsne = TSNE(n_components=2, learning_rate=150, perplexity=perp,\n",
    "                angle=0.2, verbose=2,\n",
    "                n_iter=num_iterations, random_state=5).fit_transform(X)\n",
    "\n",
    "    # normalise points to the range [0, 1] on both axes\n",
    "    tx, ty = tsne[:,0], tsne[:,1]\n",
    "    tx = (tx-np.min(tx)) / (np.max(tx) - np.min(tx))\n",
    "    ty = (ty-np.min(ty)) / (np.max(ty) - np.min(ty))\n",
    "\n",
    "    # canvas size and the maximum side length of a pasted tile, in pixels\n",
    "    width = 4000\n",
    "    height = 3000\n",
    "    max_dim = 100\n",
    "\n",
    "    full_image = Image.new('RGBA', (width, height))\n",
    "    for img, x, y in zip(images, tx, ty):\n",
    "        # open each file in a with-block so its handle is released straight away\n",
    "        with Image.open(img) as source:\n",
    "            # shrink factor; never below 1, so small images are not enlarged\n",
    "            rs = max(1, source.width/max_dim, source.height/max_dim)\n",
    "            # keep both sides at least 1 px so very elongated images do not resize to zero\n",
    "            new_size = (max(1, int(source.width/rs)), max(1, int(source.height/rs)))\n",
    "            # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter\n",
    "            tile = source.resize(new_size, Image.LANCZOS)\n",
    "        full_image.paste(tile, (int((width-max_dim)*x), int((height-max_dim)*y)), mask=tile.convert('RGBA'))\n",
    "\n",
    "    plt.figure(figsize = (16,12))\n",
    "    imshow(full_image)\n",
    "\n",
    "    # timestamp the output so repeated runs do not overwrite each other\n",
    "    ts = time.time()\n",
    "    st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d_%H-%M-%S')\n",
    "    filename = \"tSNE_\" + name + \"_x\" + str(len(images)) + \"_n\" + str(num_iterations) + \"_p\" + str(perp) + \"_\" + st\n",
    "    print(filename)\n",
    "    full_image.save(filename + \".png\")\n",
    "else:\n",
    "    # make the skip visible instead of silently producing no output\n",
    "    print(\"only %d images (fewer than 300): skipping PCA/t-SNE, no image map written\" % len(images))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get accreditation for all images used"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# go through all of the retrieved SQL rows and format as an accreditation\n",
    "# if html is True, format with <a href=####> for web usage\n",
    "# input = [[author, title, date, identifier], [author, title, date, identifier],...[author, title, date, identifier]]\n",
    "\n",
    "def format_credits(rows, html=False):\n",
    "    accreditations = []\n",
    "\n",
    "    empty_counter = 0\n",
    "    \n",
    "    # grab the details and check each\n",
    "    for i, row in enumerate(rows):\n",
    "        print(i)\n",
    "        if row:\n",
    "            s = row[0]\n",
    "            start = \"['\"\n",
    "            end = \"']\"\n",
    "            # if the author has been written to the SQLite database with parenthesis\n",
    "            if s.find(start) != -1 and s.rfind(end) != -1:\n",
    "                author = s[s.find(start)+len(start):s.rfind(end)-len(end)]\n",
    "            else:\n",
    "                author = s[:]\n",
    "            print(author)\n",
    "\n",
    "            title = row[1]\n",
    "            # replace line breaks and double spaces\n",
    "            title = title.replace(\"\\n\", \"\").replace(\"  \",\" \")\n",
    "            print(title)\n",
    "\n",
    "            date = row[2].split(\"-\")[0]\n",
    "            print(date)\n",
    "\n",
    "            identifier = row[3]\n",
    "            print(identifier)\n",
    "        #     reg_exp = re.compile(\"/[^\\d]\\d{2}[^\\d]/\")\n",
    "            longest_digits = max(re.findall(r'\\d+', identifier), key = len)\n",
    "            print(len(longest_digits))\n",
    "\n",
    "            # if the identifier contains seven consecutive numbers, add a slash\n",
    "            if len(longest_digits) == 7:\n",
    "                print(\"----- regex match -----\")\n",
    "                reverse = identifier[::-1]\n",
    "                print(reverse)\n",
    "                identifier_reverse = reverse[:7] + \"/\" + reverse[7:]\n",
    "                identifier = identifier_reverse[::-1]\n",
    "            else:\n",
    "                # otherwise we can leave the identifier how it is\n",
    "                print(\"----- no match -----\")\n",
    "            print(identifier)\n",
    "            url = \"https://arxiv.org/abs/\" + identifier\n",
    "            print(url)\n",
    "            print(\"*\" * 20)\n",
    "\n",
    "            # format string and append\n",
    "            if html:\n",
    "                fmt_str = '{}: {}, {}, <a href=\"{}\">{}</a>'\n",
    "                accreditations.append(fmt_str.format(author, title, str(date), url, url))  \n",
    "            else:\n",
    "                fmt_str = '{}: {}, {}, {}'\n",
    "                accreditations.append(fmt_str.format(author, title, str(date), url, url))\n",
    "        else:\n",
    "            empty_counter += 1\n",
    "            print(\"empty!\")\n",
    "\n",
    "    print(\"number of empty slots:\", empty_counter)\n",
    "    return accreditations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get the accreditations, find the paper identifiers of all the images we have used. Remove duplicates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# metadata.id is the last selected column (row[6]); one entry per image, duplicates included\n",
    "meta_ids = [row[6] for row in rows]\n",
    "\n",
    "print(\"meta_ids length:\", len(meta_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unique(ilist):\n",
    "    ulist = []\n",
    "    for x in ilist:\n",
    "        if x not in ulist:\n",
    "            ulist.append(x)\n",
    "    return ulist\n",
    "\n",
    "# remove duplicate paper ids while keeping their first-seen order\n",
    "meta_ids = unique(meta_ids)\n",
    "print(len(meta_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up authors/title/created/identifier for every paper that contributed an image.\n",
    "# A separate list is used so the image rows selected earlier are not overwritten,\n",
    "# which keeps the earlier cells re-runnable.\n",
    "credit_rows = []\n",
    "\n",
    "sql = (\"SELECT metadata.authors, metadata.title, metadata.created, metadata.identifier \"\n",
    "    \"FROM metadata \"\n",
    "    \"WHERE metadata.id is ? \")\n",
    "\n",
    "# iterate over the meta_ids list and grab sql data\n",
    "for file_id in meta_ids:\n",
    "    print(\"image parent article id:\",file_id)\n",
    "    c.execute(sql, (file_id, ))\n",
    "    # fetchone() returns None when nothing matches (indexing an empty fetchall() raised\n",
    "    # IndexError); format_credits counts such falsy rows as empty\n",
    "    row = c.fetchone()\n",
    "    credit_rows.append(row)\n",
    "    print(row)\n",
    "\n",
    "accreditations = format_credits(credit_rows, False)\n",
    "\n",
    "# build the output path explicitly instead of changing the kernel's working directory,\n",
    "# so relative paths used by earlier cells keep working on a re-run\n",
    "out_dir = os.path.expanduser(\"~/documentation/data-samples/\")\n",
    "credits_path = os.path.join(out_dir, filename + \".txt\")\n",
    "print(credits_path)\n",
    "\n",
    "# one accreditation per line; the with-block closes the file even if a write fails\n",
    "with open(credits_path, \"w\") as f:\n",
    "    for line in accreditations:\n",
    "        f.write(str(line) + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
